def sample_tif(args):
    """Sample a raster at a set of points after reprojecting it to EPSG:4326.

    Parameters
    ----------
    args : sequence
        ``(tif_path, lat, lon)`` packed into a single argument.
        ``lat`` and ``lon`` are equal-length sequences of coordinates; they
        are paired element by element (point i is ``lat[i], lon[i]``).

    Returns
    -------
    tuple
        ``(tif_path, data_list)`` where ``data_list`` holds one value per
        point, taken from the nearest raster cell.  Only the first entry
        along the raster's leading axis is returned (presumably the band
        axis -- confirm for multi-band inputs).
    """
    import xarray as xr
    # Imported for its side effect only: it registers the ``.rio`` accessor.
    import rioxarray as rxr

    tif_path, lat, lon = args[0], args[1], args[2]

    # The reprojected copy is an independent in-memory array, so the source
    # file handle can be released as soon as reprojection is done.
    with xr.open_dataarray(tif_path) as src:
        ds = src.rio.reproject("EPSG:4326")

    # Indexers that share a new dimension ('z') make xarray select point by
    # point instead of building the full lat x lon grid.
    lats = xr.DataArray(lat, dims='z')
    lons = xr.DataArray(lon, dims='z')
    data = ds.sel(x=lons, y=lats, method='nearest')

    data_list = list(data.values[0])
    return (tif_path, data_list)

def pptnc_to_tif(pair):
    """Export one time step of a NetCDF file's ``precip`` variable as a GeoTIFF.

    Parameters
    ----------
    pair : sequence
        ``(file, time_step)``: path of the NetCDF file and the value to
        select along its ``time`` coordinate.

    The raster is written to a fixed location whose name ends with the first
    10 characters of ``str(time_step)`` (a ``YYYY-MM-DD`` date for datetime
    values).  Nothing is returned.
    """
    # Imported for its side effect only: it registers the ``.rio`` accessor.
    import rioxarray as rxr
    import xarray as xr

    file, time_step = pair[0], pair[1]

    # NOTE(review): there is no path separator after "Senegal_ppt", so the
    # file is written as "...\Dynamic\Senegal_ppt<date>.tif" rather than into
    # a "Senegal_ppt" folder (which is where get_date_files looks) -- confirm
    # this is intended.
    output = "E:\\Tara_Fall_2019\\Senegal_Veg_Model\\Dynamic\\Senegal_ppt" + str(time_step)[:10] + ".tif"

    # The context manager closes the NetCDF file once the raster is written.
    with xr.open_dataset(file) as nc_file:
        precip = nc_file.precip.sel(time=time_step)
        # The variable uses "lon"/"lat" rather than the "x"/"y" names
        # rioxarray looks for by default, so name them explicitly.
        precip = precip.rio.set_spatial_dims("lon", "lat", inplace=True)
        precip = precip.rio.set_crs("epsg:4326")
        precip.rio.to_raster(output)

def tmaxnc_to_tif(pair):
    """Export one time step of a NetCDF file's ``tmax`` variable as a GeoTIFF.

    Parameters
    ----------
    pair : sequence
        ``(file, time_step)``: path of the NetCDF file and the value to
        select along its ``time`` coordinate.

    The output path is the input path minus its last 7 characters, followed
    by the first 10 characters of ``str(time_step)`` and ``".tif"``.
    Nothing is returned.
    """
    # Imported for its side effect only: it registers the ``.rio`` accessor.
    import rioxarray as rxr
    import xarray as xr

    file, time_step = pair[0], pair[1]

    # NOTE(review): dropping exactly 7 characters presumably strips a
    # "YYYY.nc"-style suffix -- confirm against the real file names.
    output = file[:-7] + str(time_step)[:10] + ".tif"

    # The context manager closes the NetCDF file once the raster is written.
    with xr.open_dataset(file) as nc_file:
        tmax = nc_file.tmax.sel(time=time_step)
        tmax = tmax.rio.set_crs("epsg:4326")
        tmax.rio.to_raster(output)
    
def tminnc_to_tif(pair):
    """Export one time step of a NetCDF file's ``tmin`` variable as a GeoTIFF.

    Parameters
    ----------
    pair : sequence
        ``(file, time_step)``: path of the NetCDF file and the value to
        select along its ``time`` coordinate.

    The output path is the input path minus its last 7 characters, followed
    by the first 10 characters of ``str(time_step)`` and ``".tif"``.
    Nothing is returned.
    """
    # Imported for its side effect only: it registers the ``.rio`` accessor.
    import rioxarray as rxr
    import xarray as xr

    file, time_step = pair[0], pair[1]

    # NOTE(review): dropping exactly 7 characters presumably strips a
    # "YYYY.nc"-style suffix -- confirm against the real file names.
    output = file[:-7] + str(time_step)[:10] + ".tif"

    # The context manager closes the NetCDF file once the raster is written.
    with xr.open_dataset(file) as nc_file:
        tmin = nc_file.tmin.sel(time=time_step)
        tmin = tmin.rio.set_crs("epsg:4326")
        tmin.rio.to_raster(output)

    # input a date, get all date files for precip and temp 
def get_date_files(savi_date,
                   gsmap_path="F:\\Senegal_Veg_Model\\Dynamic\\Senegal_ppt",
                   noaa_path="F:\\Senegal_Veg_Model\\Dynamic\\Senegal_temp\\Tavg"):
    """Collect the precipitation and temperature rasters preceding a date.

    Four look-back windows are built -- 16, 32, 48 and 64 days -- each of
    which starts at ``savi_date`` itself and runs backwards one day at a
    time.  A file belongs to a window when its full path contains the
    ``YYYY-MM-DD`` text of any day in that window.

    Parameters
    ----------
    savi_date : datetime.date or datetime.datetime
        Last day of every window.
    gsmap_path : str, optional
        Directory of precipitation rasters; entries whose name contains
        ``".tif"`` are considered.
    noaa_path : str, optional
        Directory of temperature rasters; entries whose name contains
        ``"tavg.TIF"`` are considered.

    Returns
    -------
    tuple of list
        ``(precip_16, precip_32, precip_48, precip_64,
        temp_16, temp_32, temp_48, temp_64)``; each list holds full paths in
        directory-listing order.
    """
    import datetime
    import os

    windows = (16, 32, 48, 64)

    # Build the longest window once; every shorter window is a prefix of it.
    date_strs = [
        str(savi_date - datetime.timedelta(days=offset))[:10]
        for offset in range(max(windows))
    ]

    ppt_files = [
        os.path.join(gsmap_path, name)
        for name in os.listdir(gsmap_path)
        if ".tif" in name
    ]
    temp_files = [
        os.path.join(noaa_path, name)
        for name in os.listdir(noaa_path)
        if "tavg.TIF" in name
    ]

    def _matching(paths, days):
        # Keep the paths that mention any date of the `days`-long window.
        wanted = date_strs[:days]
        return [path for path in paths if any(day in path for day in wanted)]

    precip = tuple(_matching(ppt_files, days) for days in windows)
    temp = tuple(_matching(temp_files, days) for days in windows)
    return precip + temp

    # get sum of rasters and sample 
def sum_and_sample(args):
    """Sum a stack of rasters cell by cell and sample the total at points.

    Parameters
    ----------
    args : sequence
        ``(tif_paths, lat, lon, variable)``: the rasters to add up, two
        equal-length coordinate sequences paired element by element, and a
        name used to label the result.

    Returns
    -------
    tuple
        ``("<N>_day_sum_<variable>", data_list)``.  ``N`` is 16, 32, 48 or 64
        when the number of rasters equals that window or is one short of it,
        and 0 for any other count.  ``data_list`` holds one value per point,
        taken from the nearest cell of the first entry along the leading
        axis (presumably the band axis -- confirm for multi-band inputs).
    """
    import xarray as xr
    # Imported for its side effect only: it registers the ``.rio`` accessor.
    import rioxarray as rxr

    tif_paths, lat, lon, variable = args[0], args[1], args[2], args[3]

    # A count one short of the window is still labelled with the window
    # (presumably to tolerate a single missing day -- confirm).
    window_by_count = {15: 16, 16: 16, 31: 32, 32: 32,
                       47: 48, 48: 48, 63: 64, 64: 64}
    length = window_by_count.get(len(tif_paths), 0)

    # Each reprojected copy is an independent in-memory array, so every
    # source file is closed right away instead of keeping them all open.
    ds = []
    for tif_path in tif_paths:
        with xr.open_dataarray(tif_path) as src:
            ds.append(src.rio.reproject("EPSG:4326"))

    tif_sum = sum(ds)

    # Indexers that share a new dimension ('z') make xarray select point by
    # point instead of building the full lat x lon grid.
    lats = xr.DataArray(lat, dims='z')
    lons = xr.DataArray(lon, dims='z')
    data = tif_sum.sel(x=lons, y=lats, method='nearest')

    data_list = list(data.values[0])
    return ((str(length) + "_day_sum_" + variable), data_list)

    # get mean of rasters and sample - arguments = paths of tifs to calculate and sample, lat, lon, variable string
def mean_and_sample(args):
    """Average a stack of rasters cell by cell and sample the mean at points.

    Parameters
    ----------
    args : sequence
        ``(tif_paths, lat, lon, variable)``: the rasters to average, two
        equal-length coordinate sequences paired element by element, and a
        name used to label the result.

    Returns
    -------
    tuple
        ``("<N>_day_mean_<variable>", data_list)``.  ``N`` is 16, 32, 48 or
        64 when the number of rasters equals that window or is one short of
        it, and 0 for any other count.  ``data_list`` holds one value per
        point, taken from the nearest cell of the first entry along the
        leading axis (presumably the band axis -- confirm for multi-band
        inputs).
    """
    import xarray as xr
    # Imported for its side effect only: it registers the ``.rio`` accessor.
    import rioxarray as rxr

    tif_paths, lat, lon, variable = args[0], args[1], args[2], args[3]

    # A count one short of the window is still labelled with the window
    # (presumably to tolerate a single missing day -- confirm).
    window_by_count = {15: 16, 16: 16, 31: 32, 32: 32,
                       47: 48, 48: 48, 63: 64, 64: 64}
    length = window_by_count.get(len(tif_paths), 0)

    # Each reprojected copy is an independent in-memory array, so every
    # source file is closed right away instead of keeping them all open.
    ds = []
    for tif_path in tif_paths:
        with xr.open_dataarray(tif_path) as src:
            ds.append(src.rio.reproject("EPSG:4326"))

    # The divisor is the number of rasters actually supplied, not the
    # nominal window length.
    tif_mean = sum(ds) / len(tif_paths)

    # Indexers that share a new dimension ('z') make xarray select point by
    # point instead of building the full lat x lon grid.
    lats = xr.DataArray(lat, dims='z')
    lons = xr.DataArray(lon, dims='z')
    data = tif_mean.sel(x=lons, y=lats, method='nearest')

    data_list = list(data.values[0])
    return ((str(length) + "_day_mean_" + variable), data_list)

# get stdv of rasters and sample - arguments = paths of tifs to calculate and sample, lat, lon, variable string
def stdv_and_sample(args):
    """Compute the per-cell standard deviation of rasters and sample it at points.

    The deviation is the population form: squared distances from the mean
    are divided by the number of rasters, not by that number minus one.

    Parameters
    ----------
    args : sequence
        ``(tif_paths, lat, lon, variable)``: the rasters to process, two
        equal-length coordinate sequences paired element by element, and a
        name used to label the result.

    Returns
    -------
    tuple
        ``("<N>_day_stdv_<variable>", data_list)``.  ``N`` is 16, 32, 48 or
        64 when the number of rasters equals that window or is one short of
        it, and 0 for any other count.  ``data_list`` holds one value per
        point, taken from the nearest cell of the first entry along the
        leading axis (presumably the band axis -- confirm for multi-band
        inputs).
    """
    import xarray as xr
    # Imported for its side effect only: it registers the ``.rio`` accessor.
    import rioxarray as rxr
    import numpy as np

    tif_paths, lat, lon, variable = args[0], args[1], args[2], args[3]

    # A count one short of the window is still labelled with the window
    # (presumably to tolerate a single missing day -- confirm).
    window_by_count = {15: 16, 16: 16, 31: 32, 32: 32,
                       47: 48, 48: 48, 63: 64, 64: 64}
    length = window_by_count.get(len(tif_paths), 0)

    # Each reprojected copy is an independent in-memory array, so every
    # source file is closed right away instead of keeping them all open.
    ds = []
    for tif_path in tif_paths:
        with xr.open_dataarray(tif_path) as src:
            ds.append(src.rio.reproject("EPSG:4326"))

    tif_mean = sum(ds) / len(tif_paths)
    sum_sq_dist = sum((tif - tif_mean) ** 2 for tif in ds)
    tif_stdv = np.sqrt(sum_sq_dist / len(tif_paths))

    # Indexers that share a new dimension ('z') make xarray select point by
    # point instead of building the full lat x lon grid.
    lats = xr.DataArray(lat, dims='z')
    lons = xr.DataArray(lon, dims='z')
    data = tif_stdv.sel(x=lons, y=lats, method='nearest')

    data_list = list(data.values[0])
    return ((str(length) + "_day_stdv_" + variable), data_list)

# In Landsat 8-9, EVI = 2.5 * ((Band 5 – Band 4) / (Band 5 + 6 * Band 4 – 7.5 * Band 2 + 1))
#### CALCULATE EVI FOR EACH IMAGE - feed in a unique file id with tile and date
def get_EVI(unique_id):
    """Compute EVI for one Landsat scene and write it as a single-band GeoTIFF.

    EVI = 2.5 * (NIR - red) / (NIR + 6 * red - 7.5 * blue + 1), using band 5
    as NIR, band 4 as red and band 2 as blue.  Cells whose NIR value is 0 are
    written as the nodata value -19999.

    Parameters
    ----------
    unique_id : str
        Scene identifier; the band files are looked up as
        ``<unique_id>_02_T1_SR_B<n>.TIF`` in the input directory.  A leading
        path separator on the id is accepted and ignored.

    Any failure while reading, computing or writing is reported with a
    "Corrupted file" message instead of being raised, so one bad scene does
    not stop a batch.  Nothing is returned.
    """
    import rasterio
    import numpy as np

    directory = "E:\\Tara_Fall_2019\\Senegal_Veg_Model\\Dynamic\\Senegal_Landsat\\Bulk_Order_Senegal\\Peanut_Basin"

    # The directory has no trailing separator, so plain concatenation only
    # produced a valid path when the id itself began with one.  Normalising
    # makes both forms of id resolve to the same files.
    scene = unique_id.lstrip("\\/")
    b4file_path = directory + "\\" + scene + "_02_T1_SR_B4.TIF"
    b5file_path = directory + "\\" + scene + "_02_T1_SR_B5.TIF"
    b2file_path = directory + "\\" + scene + "_02_T1_SR_B2.TIF"

    try:
        with rasterio.open(b4file_path) as b4, \
                rasterio.open(b5file_path) as b5, \
                rasterio.open(b2file_path) as b2:
            print ("files open.")

            # Convert to float before any arithmetic: on integer rasters
            # (Landsat surface-reflectance bands are unsigned 16-bit)
            # ``nir - red`` and ``6 * red`` would silently wrap around.
            red = b4.read(1).astype("float64")
            nir = b5.read(1).astype("float64")
            blue = b2.read(1).astype("float64")

            kwargs = b2.meta

        kwargs.update(
            dtype=rasterio.float64,
            count=1,
            nodata=-19999)

        # NOTE(review): the bands are used as stored, with no reflectance
        # scale/offset applied; the "+ 1" term of EVI assumes reflectance in
        # the 0-1 range -- confirm whether scaling is needed.
        # A zero denominator yields inf/nan for that cell; only the runtime
        # warning is silenced here.
        with np.errstate(divide="ignore", invalid="ignore"):
            index = 2.5 * ((nir - red) / (nir + (6 * red) - (7.5 * blue) + 1))
        evi = np.where(nir == 0.0, -19999, index)

        evi_new_path = "E:\\Tara_Fall_2019\\Senegal_Veg_Model\\Dynamic\\Senegal_Landsat\\Bulk_Order_Senegal\\Peanut_Basin_EVI\\" + scene + "_EVI.TIF"
        # The context manager closes (and flushes) the output even when the
        # write fails.
        with rasterio.open(evi_new_path, 'w', **kwargs) as eviImage:
            eviImage.write(evi, 1)
    except Exception as err:
        # Best effort by design, but keep Ctrl-C working and show the cause.
        print ("Corrupted file: ", unique_id, err)


def get_VI_zscore(arg):
    """Write a z-score raster of one scene relative to a set of scenes.

    Every raster in ``meta_files`` is reprojected to EPSG:4326 and matched
    to the grid of the first one.  Cells equal to -19999 are masked, the
    per-cell mean and population standard deviation are taken over all the
    rasters (the first one included), and ``(first - mean) / std`` is
    written to ``<tile>_<date>_zscore.TIF`` in a fixed output folder.

    Because the rasters are added with ``sum``, a cell that is masked in any
    one input is NaN in the result.

    Parameters
    ----------
    arg : sequence
        ``(meta_files, path, tile, date)``:
        ``meta_files`` -- raster paths; the first is the scene being scored.
        ``path`` -- a raster path that is only opened and closed again.
        ``tile``, ``date`` -- strings used to build the output file name.

    Nothing is returned.
    """
    import xarray as xr
    # Imported for its side effect only: it registers the ``.rio`` accessor.
    import rioxarray as rxr
    import numpy as np
    import rasterio as rio

    meta_files = arg[0]
    path = arg[1]

    # Nothing read from `path` is used below; it is still opened so that an
    # unreadable file fails here, but the handle is now released.
    with rio.open(path):
        pass

    tile = arg[2]
    date = arg[3]

    # Each reprojected copy is an independent in-memory array, so every
    # source file is closed right away.
    zscore_files = []
    for file in meta_files:
        with xr.open_dataarray(file) as src:
            zscore_files.append(src.rio.reproject("EPSG:4326"))

    # The first raster is the one being scored and defines the target grid.
    rxr_MAIN = zscore_files[0]
    MAIN = rxr_MAIN.to_numpy()[0]

    opened = []
    for f in zscore_files:
        matched = f.rio.reproject_match(rxr_MAIN)
        # Copy the reference coordinates so the arithmetic below lines the
        # rasters up cell for cell.
        matched = matched.assign_coords({"x": rxr_MAIN.x, "y": rxr_MAIN.y})
        # Turn the nodata marker into NaN.
        opened.append(matched.where(matched != -19999))
    print ("files opened and cleaned")

    mean = sum(opened) / len(opened)
    sum_sqd_err = sum((f - mean) ** 2 for f in opened)
    std = np.sqrt(sum_sqd_err / len(opened))

    z_score = ((MAIN - mean) / std)
    print ("z score calculated")

    new_path = "E:\\bulk_download_USGS\\Bulk_Order_Senegal\\Peanut_Basin_EVI_zscore\\" + tile + "_" + date + "_zscore.TIF"
    # Carry the georeferencing and metadata of the first raster over to the
    # result before writing it.
    z_score.rio.write_crs(opened[0].rio.crs, inplace=True)
    z_score.rio.update_attrs(opened[0].attrs, inplace=True)
    z_score.rio.update_encoding(opened[0].encoding, inplace=True)
    z_score.rio.to_raster(new_path)